{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "82047528",
   "metadata": {},
   "source": [
    "# NeuralCompression Flop Counter Example\n",
    "\n",
    "Welcome! In this notebook we'll walk through using `neuralcompression`'s flop counter to calculate the computational complexity of PyTorch models."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a0792c8f",
   "metadata": {},
   "source": [
    "## Setup\n",
    "\n",
    "First, if you haven't yet installed `neuralcompression`, do so now with:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "65e3f182",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install git+https://github.com/facebookresearch/NeuralCompression/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "24a69e05",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "import neuralcompression.functional as ncF \n",
    "from neuralcompression.models import ScaleHyperprior"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "571a539d",
   "metadata": {},
   "source": [
    "## Basic Usage\n",
    "\n",
    "The `neuralcompression` flop counter can be found at `neuralcompression.functional.count_flops`. To get started with the flop counter, two arguments need to be passed:\n",
    "\n",
    "- The model (an `nn.Module`) that you want to evaluate.\n",
    "- A list of inputs that the model should be evaluated on.\n",
    "\n",
    "As an example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "0dc3f2d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = nn.Sequential(\n",
    "    nn.Conv2d(3, 16, kernel_size=3, padding=1),\n",
    "    nn.BatchNorm2d(16),\n",
    "    nn.ReLU(),\n",
    "    nn.Conv2d(16, 32, kernel_size=5, stride=2, padding=2),\n",
    "    nn.BatchNorm2d(32),\n",
    "    nn.ReLU(),\n",
    "    nn.Flatten(),\n",
    "    nn.Linear(32 * 16 * 16, 10)\n",
    ")\n",
    "\n",
    "inputs = [torch.randn(5, 3, 32, 32)]\n",
    "\n",
    "results = ncF.count_flops(model, inputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "2be3fb9f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(19619840,\n",
       " {'conv': 18595840, 'batch_norm': 614400, 'linear': 409600},\n",
       " Counter())"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "results"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "91755cbc",
   "metadata": {},
   "source": [
    "The result returned by the flop counter is a 3-tuple:\n",
    "\n",
    "1. The first element records the total number of flops performed by the model.\n",
    "2. The second element breaks down this count by operation.\n",
    "3. The third element records the operations that the counter didn't know how to count flops for (more detail about these ops below). This dictionary (more specifically, a [collections.Counter](https://docs.python.org/3/library/collections.html#collections.Counter)) maps unknown op names to the number of times those ops were called in the model. In our example, all of the model's operations are supported by the counter, so this counter is empty.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "170ea68b",
   "metadata": {},
   "source": [
    "Most common ML operations (e.g. matrix multiplications, convolutions, and elementwise arithmetic operations) already have counter functions registered by either `neuralcompression` or [fvcore](https://github.com/facebookresearch/fvcore), whose flop counter `neuralcompression` uses under the hood. Other ops, such as sigmoid and tanh, have flop counts that can vary by platform. They are ignored by default, but estimated counter functions for them can optionally be turned on (see the Single-Flop Estimates section for details).\n",
    "\n",
    "If you need to add support for more ops or override the counter's default implementation, read on!"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a0f73116",
   "metadata": {},
   "source": [
    "## Advanced Usage"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a62de866",
   "metadata": {},
   "source": [
    "### How the Counter Works\n",
    "\n",
    "The flop counter in `neuralcompression` makes heavy use of the counting utilities in [fvcore](https://github.com/facebookresearch/fvcore). Counting a model's flops is a two-step process:\n",
    "\n",
    "1. Using PyTorch's [TorchScript](https://pytorch.org/docs/stable/jit.html) capabilities, the model is first JIT-traced into a computational graph. Each node in the graph corresponds to an operation from ATen (PyTorch's tensor library), such as a matrix multiplication, a convolution, or an elementwise operation like addition or subtraction.\n",
    "\n",
    "2. Every node in the traced graph is iterated over, and if a node is associated with a registered flop-counting function, that function is invoked and the counted flops are added to the model's total.\n",
    "\n",
    "In this second step, the counter's registered flop functions are stored as a dictionary mapping operator names (e.g. `aten::add`, `aten::matmul`) to counter functions. These functions have a signature of:\n",
    "\n",
    "```python\n",
    "def my_counter_function(inputs: List[torch._C.Value], outputs: List[torch._C.Value]) -> float: ...\n",
    "```\n",
    "\n",
    "where objects of type `torch._C.Value` represent the symbolic inputs and outputs for the node in the graph (i.e. they're not actual concrete `Tensor`s). Check out the `fvcore` codebase for [examples of how to write counter functions](https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/jit_handles.py)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "317c09bd",
   "metadata": {},
   "source": [
    "### What is Counted as a Flop\n",
    "\n",
    "Note that all of the built-in counter functions consider one fused multiply-accumulate (MAC) to be one flop. This means that a matrix multiplication between an `N x K` and a `K x M` matrix is recorded as `N * K * M` flops (since `N * K * M` MACs are performed). For non-fused elementwise arithmetic operations such as addition, subtraction, absolute value, etc., one flop is counted for each element of the output tensor. For example, adding two `N x K` matrices counts as `N * K` flops."
   ]
  },
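  {
   "cell_type": "markdown",
   "id": "3b6e0f1a",
   "metadata": {},
   "source": [
    "As a sanity check on this convention, the `conv` and `linear` counts reported for the Basic Usage model can be reproduced by hand. The cell below is a minimal sketch in plain Python: the helpers `conv2d_macs` and `linear_macs` are hypothetical names introduced here for illustration (they are not part of `neuralcompression` or `fvcore`), and the sketch assumes that bias additions are not counted, which is consistent with the totals shown earlier. Both assertions should pass against the numbers printed in Basic Usage."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d42c9e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "def conv2d_macs(batch, in_channels, out_channels, kernel_size, out_height, out_width):\n",
    "    # One MAC per (output element, input channel, kernel position).\n",
    "    return batch * out_channels * out_height * out_width * in_channels * kernel_size ** 2\n",
    "\n",
    "\n",
    "def linear_macs(batch, in_features, out_features):\n",
    "    # One MAC per (output element, input feature).\n",
    "    return batch * in_features * out_features\n",
    "\n",
    "\n",
    "# Shapes taken from the Basic Usage model with a 5 x 3 x 32 x 32 input.\n",
    "# The second convolution has stride 2, so its output is 16 x 16.\n",
    "conv_flops = (\n",
    "    conv2d_macs(5, 3, 16, kernel_size=3, out_height=32, out_width=32)\n",
    "    + conv2d_macs(5, 16, 32, kernel_size=5, out_height=16, out_width=16)\n",
    ")\n",
    "linear_flops = linear_macs(5, 32 * 16 * 16, 10)\n",
    "\n",
    "# Compare against the per-operation breakdown returned by count_flops above.\n",
    "assert conv_flops == 18595840\n",
    "assert linear_flops == 409600"
   ]
  },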
  {
   "cell_type": "markdown",
   "id": "2892a797",
   "metadata": {},
   "source": [
    "### Single-Flop Estimates\n",
    "\n",
    "As mentioned in the previous section, simple elementwise operations like addition and multiplication are counted as one flop per output tensor element. More complicated elementwise operations, such as logarithms, exponentiation, and the sigmoid function, are by default unregistered by the flop counter (i.e. these ops do not contribute to the total flop count and are returned as unrecognized ops), since the number of flops performed by these operations can vary by platform and implementation.\n",
    "\n",
    "If your model contains many of these ops and you want a rough estimate of their contribution to model complexity, the counter exposes the flag `use_single_flop_estimates`, which defaults to `False`. When set to `True`, the counter treats these more complicated elementwise operations like simple arithmetic ops and registers counter functions that count one flop per element of the output tensor. This likely undercounts the contribution of these operations, but an undercount may be more informative when analyzing a model than ignoring the operations altogether.\n",
    "\n",
    "If you know exactly how many flops some of these ops should have on your platform, use the `counter_overrides` argument described in the next section.\n",
    "\n",
    "With the counter's default supported ops plus the added estimate counters, complicated models such as the [Scale Hyperprior Model](https://arxiv.org/abs/1802.01436) should pass through the flop counter with few to no unregistered ops:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "72eaddf7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "big_model = ScaleHyperprior(network_channels=32, compression_channels=64)\n",
    "_, _, unsupported_ops = ncF.count_flops(\n",
    "    big_model, (torch.randn(1, 3, 64, 64),), use_single_flop_estimates=True\n",
    ")\n",
    "\n",
    "len(unsupported_ops) == 0  # True when every op in the model was counted"
   ]
  },
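  {
   "cell_type": "markdown",
   "id": "e15a8b36",
   "metadata": {},
   "source": [
    "To see what the flag changes, it can help to run the counter twice on a tiny model that contains a single sigmoid. The following is a sketch only: `small_model` and the other variable names are made up for this example, and the exact totals and op names reported may depend on your PyTorch and `fvcore` versions, so no expected output is shown. Based on the description above, the sigmoid should appear among the unsupported ops in the first call and contribute to the total in the second."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f9c27d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "import neuralcompression.functional as ncF\n",
    "\n",
    "# A small model with one op that is unregistered by default (sigmoid).\n",
    "small_model = nn.Sequential(nn.Linear(8, 4), nn.Sigmoid())\n",
    "small_inputs = (torch.randn(2, 8),)\n",
    "\n",
    "# Default behavior: complicated elementwise ops are reported as unsupported.\n",
    "default_total, _, default_unsupported = ncF.count_flops(small_model, small_inputs)\n",
    "\n",
    "# With the flag on, those ops are estimated at one flop per output element.\n",
    "estimated_total, _, estimated_unsupported = ncF.count_flops(\n",
    "    small_model, small_inputs, use_single_flop_estimates=True\n",
    ")\n",
    "\n",
    "print(default_total, dict(default_unsupported))\n",
    "print(estimated_total, dict(estimated_unsupported))"
   ]
  },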
  {
   "cell_type": "markdown",
   "id": "52f36115",
   "metadata": {},
   "source": [
    "### Customizing the Counter\n",
    "\n",
    "To register additional counter functions or override the counter's default implementation for specific operations, use the `counter_overrides` argument. This argument takes the form of a dictionary mapping ATen operator names to the corresponding counter functions you wish to use. As an example, consider the following simple module:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "aa3b8c44",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MyModule(nn.Module):\n",
    "    def forward(self, x, y, z):\n",
    "        # z is accepted but unused; only x and y affect the output.\n",
    "        return (x + y).abs()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fdb5c1db",
   "metadata": {},
   "source": [
    "By default, the elementwise addition of two tensors of shape `N x M`, or taking the absolute value of an `N x M` tensor, contributes `N * M` flops to the total computational complexity of a model (since each scalar addition or absolute value is counted as 1 flop):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "66b78a26",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "N = 5\n",
    "M = 32\n",
    "\n",
    "inp = torch.randn(N, M)\n",
    "\n",
    "flops, _, _ = ncF.count_flops(MyModule(), (inp, inp, inp))\n",
    "flops == 2 * N * M"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ea4d5521",
   "metadata": {},
   "source": [
    "However, let's say that you wanted to ignore all the flops coming from addition operations. You could do this as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "5f7a806c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "new_flops, _, _ = ncF.count_flops(\n",
    "    MyModule(),\n",
    "    (inp, inp, inp),\n",
    "    counter_overrides={\n",
    "        \"aten::add\": lambda inps, outs: 0.0\n",
    "    }\n",
    ")\n",
    "\n",
    "new_flops == N * M"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9c33a317",
   "metadata": {},
   "source": [
    "Now only the `N * M` flops from the absolute value are counted; the `N * M` flops from the addition are ignored."
   ]
  },
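  {
   "cell_type": "markdown",
   "id": "b8036ea1",
   "metadata": {},
   "source": [
    "As a further illustration, a counter function can also compute its result from the symbolic output shape instead of returning a constant. The sketch below registers a hypothetical cost of 4 flops per output element for `aten::sigmoid`. The function `four_flops_per_element`, the module `SigmoidModule`, and the factor of 4 are all assumptions made for this example, not values used by `neuralcompression` or `fvcore`. It also assumes that the traced output value carries a complete tensor type, so that `type().sizes()` returns its shape."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c62d714f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "import neuralcompression.functional as ncF\n",
    "\n",
    "\n",
    "def four_flops_per_element(inputs, outputs):\n",
    "    # outputs[0] is a symbolic torch._C.Value; its traced type carries the shape.\n",
    "    num_elements = 1\n",
    "    for dim in outputs[0].type().sizes():\n",
    "        num_elements *= dim\n",
    "    # The factor of 4 is a made-up cost, used for illustration only.\n",
    "    return 4.0 * num_elements\n",
    "\n",
    "\n",
    "class SigmoidModule(nn.Module):\n",
    "    def forward(self, x):\n",
    "        return torch.sigmoid(x)\n",
    "\n",
    "\n",
    "sigmoid_flops, _, _ = ncF.count_flops(\n",
    "    SigmoidModule(),\n",
    "    (torch.randn(5, 32),),\n",
    "    counter_overrides={\"aten::sigmoid\": four_flops_per_element},\n",
    ")\n",
    "\n",
    "sigmoid_flops"
   ]
  },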
  {
   "cell_type": "markdown",
   "id": "a9087685",
   "metadata": {},
   "source": [
    "## Conclusion\n",
    "\n",
    "For more details on the internals of the flop counter and how to write counter functions, check out [fvcore on GitHub](https://github.com/facebookresearch/fvcore/tree/master/fvcore/nn). Happy counting!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd5cb945",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
